{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# 外部张量函数\n",
        "**原作者**: [Tianqi Chen](https://tqchen.github.io)\n",
        "\n",
        "虽然 TVM 支持透明的代码生成，但有时将手动编写的代码合并到 pipeline 中也是有帮助的。例如，可能想要使用 cuDNN 来处理一些卷积核，并定义其余的阶段。\n",
        "\n",
        "TVM本身就支持这些黑盒函数调用。具体来说，TVM 支持所有与 DLPack 兼容的张量函数。这意味着可以调用任何带有 POD 类型（pointer、int、float）或指向 DLTensor 的指针作为参数的函数。\n",
        "\n",
        "````{note}\n",
        "这里需要设定 cmake：\n",
        "\n",
        "```cmake\n",
        "# Whether use BLAS, choices: openblas, atlas, apple\n",
        "set(USE_BLAS atlas)\n",
        "```\n",
        "````"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import env # 加载 TVM 环境\n",
        "import tvm\n",
        "from tvm import te\n",
        "import numpy as np\n",
        "from tvm.contrib import cblas\n",
        "import tvm.testing\n",
        "\n",
        "if not tvm.get_global_func(\"tvm.contrib.cblas.matmul\", allow_missing=True):\n",
        "    raise Exception(\"Not compiled with cblas support; can't build this tutorial\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 使用 Extern 张量函数\n",
        "\n",
        "在下面的例子中，使用 {any}`te.extern` 添加外部数组函数调用。在 extern 调用中，声明输出张量的形状。在第二个参数中，提供了输入列表。\n",
        "\n",
        "用户需要提供描述如何计算结果的函数。`compute` 函数接受输入的符号占位符列表和输出的符号占位符列表，并返回正在执行的语句。\n",
        "\n",
        "在本例中，只需调用已注册的 TVM 函数，该函数调用 CBLAS 回调。TVM 不控制 extern 数组函数的内部，并将其视为黑盒。可以进一步混合可调度的 TVM 调用，为结果添加 bias 项。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "n = 1024\n",
        "l = 128\n",
        "m = 235\n",
        "bias = te.var(\"bias\", dtype=\"float32\")\n",
        "A = te.placeholder((n, l), name=\"A\")\n",
        "B = te.placeholder((l, m), name=\"B\")\n",
        "C = te.extern(\n",
        "    (n, m),\n",
        "    [A, B],\n",
        "    lambda ins, outs: tvm.tir.call_packed(\n",
        "        \"tvm.contrib.cblas.matmul\", ins[0], ins[1], outs[0], False, False\n",
        "    ),\n",
        "    name=\"C\",\n",
        ")\n",
        "D = te.compute(C.shape, lambda i, j: C[i, j] + bias, name=\"D\")\n",
        "s = te.create_schedule(D.op)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 验证结果\n",
        "\n",
        "可以验证结果是否符合期望。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "dev = tvm.cpu(0)\n",
        "f = tvm.build(s, [A, B, D, bias], \"llvm\")\n",
        "a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), dev)\n",
        "b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), dev)\n",
        "d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), dev)\n",
        "bb = 10.0\n",
        "f(a, b, d, bb)\n",
        "tvm.testing.assert_allclose(d.numpy(), np.dot(a.numpy(), b.numpy()) + 10, rtol=1e-5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 封装 Extern Contrib\n",
        "\n",
        "TVM 还为有用的 extern 调用提供了 extern contrib 封装器，下面一行与前面的示例等价。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from tvm.contrib import cblas\n",
        "\n",
        "C = cblas.matmul(A, B)\n",
        "D = te.compute(C.shape, lambda i, j: C[i, j] + bias, name=\"D\")\n",
        "s = te.create_schedule(D.op)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Hook Python 函数作为 Extern\n",
        "\n",
        "因为可以在 TVM 中调用任何 PackedFunc。所以可以使用 extern 函数回调到 python。\n",
        "\n",
        "下面的例子注册了 python 函数到 TVM 运行时系统，并使用它来完成计算的一个阶段。这使得 TVM 更加灵活。例如，可以插入前端回调来检查中间结果，或者将定制代码与 TVM 混合使用。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "my_tvm_addone signatures: <class 'tvm.runtime.ndarray.NDArray'>, <class 'tvm.runtime.ndarray.NDArray'>\n"
          ]
        }
      ],
      "source": [
        "@tvm.register_func(\"tvm.contrib.my_tvm_addone\")\n",
        "def my_tvm_addone(x, y):\n",
        "    print(f\"my_tvm_addone signatures: {type(x)}, {type(y)}\")\n",
        "    tvm.nd.array(x.numpy() + 1).copyto(y)\n",
        "\n",
        "\n",
        "A = te.placeholder((n,), name=\"A\")\n",
        "B = te.extern(\n",
        "    A.shape,\n",
        "    [A],\n",
        "    lambda ins, outs: tvm.tir.call_packed(\"tvm.contrib.my_tvm_addone\", ins[0], outs[0]),\n",
        "    name=\"C\",\n",
        ")\n",
        "s = te.create_schedule(B.op)\n",
        "f = tvm.build(s, [A, B], \"llvm\")\n",
        "a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), dev)\n",
        "b = tvm.nd.array(np.random.uniform(size=(n,)).astype(B.dtype), dev)\n",
        "f(a, b)\n",
        "tvm.testing.assert_allclose(b.numpy(), a.numpy() + 1, rtol=1e-5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 小结\n",
        "\n",
        "- TVM 通过 {any}`te.extern` 调用外部张量函数\n",
        "- 使用 contrib 包装的外部张量调用的简短语法糖。\n",
        "- 可以将前端函数 hook 为外部张量回调函数。"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3.8.13 ('xc': conda)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.13"
    },
    "vscode": {
      "interpreter": {
        "hash": "f4772b2d9fb5f4e213cea28dc6a0e63daacdc3e8a701d5a5063e88b8cfe3308a"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
